{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import sys\n",
    "if \"../\" not in sys.path:\n",
    "  sys.path.append(\"../\") \n",
    "from lib.utils import read_data_from_file, sign\n",
    "from lib.pla import pocket_pla"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_reg_closed_form(X, y):\n",
    "    \"\"\"Solve linear regression in closed form.\n",
    "\n",
    "    The weights are obtained by applying the Moore-Penrose pseudo-inverse\n",
    "    of ``X`` to ``y``.\n",
    "\n",
    "    Args:\n",
    "        X: data matrix.\n",
    "        y: target values to regress on.\n",
    "\n",
    "    Returns:\n",
    "        w_lin: feature weights, computed as ``pinv(X).dot(y)``.\n",
    "    \"\"\"\n",
    "    return np.linalg.pinv(X).dot(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_err_rate(X, y, w):\n",
    "    \"\"\"Return the fraction of samples whose predicted sign differs from ``y``.\n",
    "\n",
    "    Args:\n",
    "        X: data matrix.\n",
    "        y: labels compared element-wise against the predictions.\n",
    "        w: weight vector applied to ``X``.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the element-wise mismatch indicator ``sign(X.dot(w)) != y``.\n",
    "    \"\"\"\n",
    "    predictions = sign(X.dot(w))\n",
    "    mismatches = predictions != y\n",
    "    return mismatches.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global random number generator with a fixed value.\n",
    "# NOTE(review): presumably pocket_pla draws from this global RNG, which would\n",
    "# make the benchmark runs reproducible - confirm in lib.pla.\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data_train shape:  (500, 5)\n",
      "data_test shape:  (500, 5)\n"
     ]
    }
   ],
   "source": [
    "# Load the training and test sets\n",
    "data_train = read_data_from_file('hw1_18_train.dat')\n",
    "print('data_train shape: ', data_train.shape)\n",
    "data_test = read_data_from_file('hw1_18_test.dat')\n",
    "print('data_test shape: ', data_test.shape)\n",
    "\n",
    "# The last column is used as the label; the remaining columns are the features.\n",
    "# A leading column of ones is prepended to the features (constant term x0 = 1).\n",
    "X = np.hstack((np.ones((len(data_train), 1)), data_train[:, :-1]))\n",
    "y = data_train[:, -1]\n",
    "X_test = np.hstack((np.ones((len(data_test), 1)), data_test[:, :-1]))\n",
    "y_test = data_test[:, -1]\n",
    "\n",
    "# Number of repetitions used by the benchmark loops\n",
    "times = 2000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# linear_reg_closed_form 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear reg for classification:\n",
      "error rate: 0.10400000000000002\tcost: 0.8960509300231934\n"
     ]
    }
   ],
   "source": [
    "# The closed-form solution is deterministic, so every repetition yields the\n",
    "# same weights; the loop runs `times` times only so that its cost can be\n",
    "# compared with the pocket PLA benchmarks.\n",
    "err_rates = []\n",
    "start = time.perf_counter()  # monotonic, high-resolution clock for timing\n",
    "for i in range(times):\n",
    "    w_lin = linear_reg_closed_form(X, y)\n",
    "    err_rates.append(get_err_rate(X_test, y_test, w_lin))\n",
    "# Stop the clock before aggregating/printing so only the loop is measured.\n",
    "elapsed = time.perf_counter() - start\n",
    "print('linear reg for classification:')\n",
    "print(\"error rate: {}\\tcost: {}\".format(np.mean(err_rates), elapsed))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# pocket_pla 测试,主要用于做对比"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pocket pla for classification:\n",
      "error rate: 0.13280200000000003\tcost: 19.535125494003296\n"
     ]
    }
   ],
   "source": [
    "# Run pocket PLA `times` times and average the test error over all runs.\n",
    "# NOTE(review): presumably pocket_pla is randomized, which is why the error\n",
    "# rate is averaged - confirm in lib.pla.\n",
    "err_rates = []\n",
    "start = time.perf_counter()  # monotonic, high-resolution clock for timing\n",
    "for i in range(times):\n",
    "    # pocket_pla returns a pair; only the first element (the weights) is used.\n",
    "    w_pocket, _ = pocket_pla(X, y)\n",
    "    err_rates.append(get_err_rate(X_test, y_test, w_pocket))\n",
    "# Stop the clock before aggregating/printing so only the loop is measured.\n",
    "elapsed = time.perf_counter() - start\n",
    "print('pocket pla for classification:')\n",
    "print(\"error rate: {}\\tcost: {}\".format(np.mean(err_rates), elapsed))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 初值为w_lin的pocket_pla"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pocket pla with w0=w_lin for classification:\n",
      "error rate: 0.10499300000000003\tcost: 19.429277181625366\n"
     ]
    }
   ],
   "source": [
    "# Run pocket PLA `times` times, passing the linear regression weights as w0,\n",
    "# and average the test error over all runs.\n",
    "err_rates = []\n",
    "start = time.perf_counter()  # monotonic, high-resolution clock for timing\n",
    "# The one-off regression solve is deliberately inside the timed region so the\n",
    "# reported cost covers the whole procedure.\n",
    "w_lin = linear_reg_closed_form(X, y)\n",
    "for i in range(times):\n",
    "    # pocket_pla returns a pair; only the first element (the weights) is used.\n",
    "    w_pocket, _ = pocket_pla(X, y, w0=w_lin)\n",
    "    err_rates.append(get_err_rate(X_test, y_test, w_pocket))\n",
    "# Stop the clock before aggregating/printing so printing is not measured.\n",
    "elapsed = time.perf_counter() - start\n",
    "print('pocket pla with w0=w_lin for classification:')\n",
    "print(\"error rate: {}\\tcost: {}\".format(np.mean(err_rates), elapsed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "273.188px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
